{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "from os import mkdir\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "sys.path.append('..')\n",
    "from config import cfg\n",
    "from data import make_data_loader\n",
    "from engine.trainer import do_train\n",
    "from modeling import build_model\n",
    "from solver import make_optimizer, WarmupMultiStepLR\n",
    "from layers import make_loss\n",
    "\n",
    "from utils.logger import setup_logger\n",
    "from utils.feats_pca import feats_map_pca_projection,feats_pca_projection\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import cv2\n",
    "from imageio_ffmpeg import write_frames\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "torch.cuda.set_device(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = 'sport_1_training'\n",
    "epoch = 908\n",
    "para_file = 'nr_model_%d.pth' % epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the settings stored in <model_path>/configs.yml into cfg.\n",
    "cfg.merge_from_file(os.path.join(model_path,'configs.yml'))\n",
    "# Disabled overrides for the train/test input size.\n",
    "#cfg.INPUT.SIZE_TRAIN = [1000,750]\n",
    "#cfg.INPUT.SIZE_TEST = [1000,750]\n",
    "# Set SOLVER.IMS_PER_BATCH to 1, then freeze the config.\n",
    "cfg.SOLVER.IMS_PER_BATCH = 1\n",
    "cfg.freeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TensorBoard writer whose log directory is <model_path>/tensorboard_test.\n",
    "writer = SummaryWriter(log_dir=os.path.join(model_path,'tensorboard_test'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the loader with is_train=False; the call also returns the underlying dataset.\n",
    "test_loader, dataset = make_data_loader(cfg, is_train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network with isTrain=False and restore the weights of the chosen epoch.\n",
    "model = build_model(cfg, isTrain=False)\n",
    "checkpoint_path = os.path.join(model_path, para_file)\n",
    "# The checkpoint is mapped to the CPU first; the module is moved to the GPU afterwards.\n",
    "model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))\n",
    "model = model.eval().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholders kept from the original cell; nothing in this notebook appends to them.\n",
    "feature_maps = []\n",
    "tars = []\n",
    "\n",
    "# Only the first batch is needed: K, T, inds and near_far_max_splatting_size from it\n",
    "# are reused by the rendering loop further down.\n",
    "batch = next(iter(test_loader), None)\n",
    "if batch is None:\n",
    "    raise RuntimeError('test_loader yielded no batches')\n",
    "\n",
    "point_indexes = batch[0]\n",
    "in_points = batch[1].cuda()\n",
    "K = batch[2].cuda()\n",
    "T = batch[3].cuda()\n",
    "num_points = batch[4]\n",
    "near_far_max_splatting_size = batch[5]\n",
    "inds = batch[6].cuda()\n",
    "target = batch[7].cuda()\n",
    "rgbs = batch[8].cuda()\n",
    "\n",
    "# Inference only, so the forward pass runs without autograd bookkeeping.\n",
    "with torch.no_grad():\n",
    "    res, depth, features, dir_in_world, rgb, m_point_features = model(\n",
    "        in_points, K, T, near_far_max_splatting_size, num_points, rgbs, inds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show channels 0..2 of the first image in the batch, moved to channels-last order.\n",
    "plt.imshow(res[0, 0:3].detach().cpu().permute(1, 2, 0))\n",
    "print(target.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset item 56 supplies the points (and colors) used to place the virtual camera.\n",
    "_, in_points, _, _, _, _, rgbs = dataset[56]\n",
    "\n",
    "# Orbit center: mean of the points.\n",
    "center = in_points.mean(dim=0).cpu()\n",
    "\n",
    "# Up direction: negated, normalized mean of column 1 of the upper 3x3 part of\n",
    "# datasets[0].Ts (presumably the camera y axes -- TODO confirm).\n",
    "up = -dataset.datasets[0].Ts[:, 0:3, 1].mean(dim=0)\n",
    "up = up / up.norm()\n",
    "\n",
    "# Orbit radius: 1.3 times the distance between column 3 of the first entry of Ts and the center.\n",
    "radius = torch.norm(dataset.datasets[0].Ts[0, 0:3, 3] - center) * 1.3\n",
    "\n",
    "# Direction (0, 0, -1) with its component along `up` removed, then normalized.\n",
    "v = torch.tensor([0, 0, -1], dtype=torch.float32)\n",
    "v = v - torch.dot(up, v) * up\n",
    "v = v / v.norm()\n",
    "\n",
    "# Overwrite entries [0, 2] and [1, 2] of every intrinsic matrix in K, in place\n",
    "# (presumably the principal point for an 800x600 image -- TODO confirm).\n",
    "K[:, 0, 2] = 400\n",
    "K[:, 1, 2] = 300\n",
    "\n",
    "# Start position: one radius away from the center against `v`, raised by 0.1 radius along `up`.\n",
    "s_pos = center - v * radius + up * radius * 0.1\n",
    "\n",
    "# The remaining pose math uses NumPy arrays and a plain float.\n",
    "center, up, s_pos = center.numpy(), up.numpy(), s_pos.numpy()\n",
    "radius = radius.item()\n",
    "\n",
    "# Unit view direction from the start position to the center, and the unit axis\n",
    "# perpendicular to it and to `up`.\n",
    "lookat = center - s_pos\n",
    "lookat = lookat / np.linalg.norm(lookat)\n",
    "\n",
    "xaxis = np.cross(lookat, up)\n",
    "xaxis = xaxis / np.linalg.norm(xaxis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import math\n",
    "def rotate(angle):\n",
    "    res = np.array([ [math.cos(angle), 0, math.sin(-angle)],[0,1,0],[ math.sin(-angle),0, math.cos(angle)]])\n",
    "    return res\n",
    "\n",
    "def rodrigues_rotation_matrix(axis, theta):\n",
    "    axis = np.asarray(axis)\n",
    "    theta = np.asarray(theta)\n",
    "    axis = axis/math.sqrt(np.dot(axis, axis))\n",
    "    a = math.cos(theta/2.0)\n",
    "    b, c, d = -axis*math.sin(theta/2.0)\n",
    "    aa, bb, cc, dd = a*a, b*b, c*c, d*d\n",
    "    bc, ad, ac, ab, bd, cd = b*c, a*d, a*c, a*b, b*d, c*d\n",
    "    return np.array([[aa+bb-cc-dd, 2*(bc+ad), 2*(bd-ac)],\n",
    "                     [2*(bc-ad), aa+cc-bb-dd, 2*(cd+ab)],\n",
    "                     [2*(bd+ac), 2*(cd-ab), aa+dd-bb-cc]])\n",
    "\n",
    "# Frame selector and its step; both are read by the rendering loop.\n",
    "index = 0\n",
    "dx = 1\n",
    "\n",
    "# One output folder per image kind. exist_ok avoids the separate existence check\n",
    "# and makes re-running the cell safe.\n",
    "for folder_pattern in ('vis_%d', 'vis_mask_%d', 'vis_depth_%d', 'vis_compos_%d'):\n",
    "    os.makedirs(os.path.join(model_path, folder_pattern % epoch), exist_ok=True)\n",
    "\n",
    "# Per-frame intrinsics, camera poses and frame ids collected for export.\n",
    "sKs = []\n",
    "sTs = []\n",
    "frames_id = []\n",
    "    \n",
    "for i in range(360):\n",
    "    # Points and colors of the dataset item selected by `index`.\n",
    "    _, in_points, _, _, _, _, rgbs = dataset.__getitem__(index*dataset.datasets[0].Ts.size(0))\n",
    "\n",
    "    num_points = torch.Tensor([in_points.size(0)])\n",
    "    rgbs = rgbs.cuda()\n",
    "    in_points = in_points.cuda()\n",
    "\n",
    "    # NOTE(review): dx is updated here but never added to index, so index keeps its\n",
    "    # initial value and every output frame uses the same item. Confirm whether\n",
    "    # `index += dx` was intended before changing this.\n",
    "    if index >= dataset.datasets[0].frame_num-1:\n",
    "        dx = -1\n",
    "\n",
    "    # Camera sweep: i % 100 restarts the sweep every 100 frames, covering angles\n",
    "    # from -50*pi/360 to 49*pi/360.\n",
    "    ii = i % 100\n",
    "    angle = math.pi*(ii-50)/360.0\n",
    "\n",
    "    # Rotate the start position about `up` through the orbit center.\n",
    "    pos = rodrigues_rotation_matrix(up, -angle).dot(s_pos - center) + center\n",
    "\n",
    "    # Camera frame that looks at the center.\n",
    "    lookat = center - pos\n",
    "    lookat = lookat/np.linalg.norm(lookat)\n",
    "\n",
    "    xaxis = np.cross(lookat, up)\n",
    "    xaxis = xaxis / np.linalg.norm(xaxis)\n",
    "\n",
    "    yaxis = -np.cross(xaxis, lookat)\n",
    "    yaxis = yaxis/np.linalg.norm(yaxis)\n",
    "\n",
    "    # 4x4 matrix whose columns are x axis, y axis, view direction and position.\n",
    "    nR = np.array([xaxis, yaxis, lookat, pos]).T\n",
    "    nR = np.concatenate([nR, np.array([[0,0,0,1]])])\n",
    "\n",
    "    sTs.append(nR)\n",
    "    sKs.append(K[0].cpu().numpy())\n",
    "    frames_id.append(index)\n",
    "\n",
    "    # The pose tensor from the first batch is overwritten in place.\n",
    "    T[0,:,:] = torch.Tensor(nR).cuda()\n",
    "    with torch.no_grad():\n",
    "        res, depth, features, dir_in_world, rgb, m_point_features = model(\n",
    "            in_points, K, T, near_far_max_splatting_size, num_points, rgbs, inds)\n",
    "\n",
    "    img_t = res.detach().cpu()[0]\n",
    "    # Channel 3 is used as the mask; this slice is a view into img_t.\n",
    "    mask_t = img_t[3:4,:,:]\n",
    "\n",
    "    # Scale to 0..255 and swap the first and third channels before cv2.imwrite.\n",
    "    img = cv2.cvtColor(img_t.permute(1,2,0).numpy()*255.0, cv2.COLOR_BGR2RGB)\n",
    "    mask = mask_t.permute(1,2,0).numpy()*255.0\n",
    "\n",
    "    # Normalize the depth map to 0..255; an all-zero map is left as it is\n",
    "    # instead of being divided by zero.\n",
    "    img_depth = depth.detach().cpu()[0][0].numpy()\n",
    "    depth_max = np.max(img_depth)\n",
    "    if depth_max > 0:\n",
    "        img_depth = img_depth*255/depth_max\n",
    "    img_depth = cv2.cvtColor(img_depth, cv2.COLOR_GRAY2RGB)\n",
    "\n",
    "    cv2.imwrite(os.path.join(model_path,'vis_%d/img_%04d.jpg'%(epoch,i)), img)\n",
    "\n",
    "    # Composite: multiply the color channels by the mask, then set them to 1.0\n",
    "    # wherever the mask is below 0.95.\n",
    "    mask_rgb = mask_t.repeat(3,1,1)\n",
    "    img_t[0:3,:,:] = img_t[0:3,:,:]*mask_rgb\n",
    "    img_t[0:3,:,:][mask_rgb<0.95] = 1.0\n",
    "    img = cv2.cvtColor(img_t.permute(1,2,0).numpy()*255.0, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "    cv2.imwrite(os.path.join(model_path,'vis_compos_%d/img_%04d.jpg'%(epoch,i)), img)\n",
    "    cv2.imwrite(os.path.join(model_path,'vis_mask_%d/img_%04d.jpg'%(epoch,i)), mask)\n",
    "    cv2.imwrite(os.path.join(model_path,'vis_depth_%d/img_%04d.jpg'%(epoch,i)), img_depth)\n",
    "\n",
    "    # Drop this frame's outputs before the next forward pass and release cached GPU memory.\n",
    "    del res, depth, features, dir_in_world, rgb, m_point_features, img\n",
    "    torch.cuda.empty_cache()\n",
    "    print(i,'/360')\n",
    "    \n",
    "    \n",
    "\n",
    "# Intrinsics: for every rendered frame, its running number followed by the nine\n",
    "# matrix entries, three per line, and a blank line.\n",
    "with open(os.path.join(model_path,'Intrinsic_%d.inf'%epoch), 'w') as f:\n",
    "    for frame_no, intrinsic in enumerate(sKs):\n",
    "        f.write('%d\\n'%frame_no)\n",
    "        f.write('%f %f %f\\n %f %f %f\\n %f %f %f\\n' % tuple(intrinsic.reshape(9).tolist()))\n",
    "        f.write('\\n')\n",
    "\n",
    "# Poses: one line of 12 values per frame, taken from the top three rows of the\n",
    "# 4x4 matrix in column order 2, 0, 1, 3 (view direction, x axis, y axis, position).\n",
    "with open(os.path.join(model_path,'CamPose_%d.inf' %epoch), 'w') as f:\n",
    "    for pose in sTs:\n",
    "        values = pose[0:3, [2, 0, 1, 3]].T.reshape(-1)\n",
    "        f.write('%f %f %f %f %f %f %f %f %f %f %f %f\\n' % tuple(values.tolist()))\n",
    "\n",
    "# Frame ids: one integer per line, in rendering order.\n",
    "with open(os.path.join(model_path,'frames_%d.inf' %epoch), 'w') as f:\n",
    "    for frame_id in frames_id:\n",
    "        f.write('%d\\n' % int(frame_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
